# -*- coding: utf-8 -*-
"""
Created on Sat Apr  1 15:09:05 2023

@author: 29672366
"""
import os
import torch
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
from models.QAModel import QAData, QAModel
import torch.nn as nn
import torch.nn.functional as F

from utils.DigitEmbedding import DigitEmbedding
from utils.SelfAttentionEncoder import SelfAttentionEncoder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = 'cpu'
torch.autograd.set_detect_anomaly(True)

traindata_path = "./traindatas/QA/train"
valdata_path = "./traindatas/QA/val"

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

max_seq_length = 500
batch_size = 1  # for multi-turn dialogue training a single batch is enough; set historynum on QAData instead
num_digits = 6
historynum = 5

train_data = QAData(traindata_path, tokenizer, max_seq_length, num_digits=num_digits, historynum=historynum)
val_data = QAData(valdata_path, tokenizer, max_seq_length, num_digits=num_digits, historynum=historynum)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, drop_last=True)
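
# Shape sketch (an assumption noted for readability; the authoritative contract is QAData in
# models/QAModel.py): each loader batch is expected to be
# (batch_size, historynum, max_seq_length, num_digits) for both inputs and targets,
# which the training loop below flattens to (batch_size * historynum, max_seq_length, num_digits).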

model = QAModel(max_seq_length=max_seq_length,
                input_max_num_sequences=batch_size*historynum,
                num_digits=num_digits)

model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5,weight_decay=1e-5)

os.makedirs("./run", exist_ok=True)

# Checkpointing (in lieu of a ModelCheckpoint callback): keep the model with the lowest validation loss
best_val_loss = float('inf')
checkpoint_path = './run/best_model.pt'

if os.path.exists(checkpoint_path):
    print('load checkpoint')
    model.load(checkpoint_path)

def calculate_loss(output, target):
    """Scale `loss` by the fraction of `target` that lies between successive token positions taken from `output`."""
    token_dict = {}

    for i, token in enumerate(output):
        if token not in token_dict:
            token_dict[token] = []
        token_dict[token].append(i)

    prev_index = 0
    loss = torch.tensor(0.0)
    total_length = len(target)
    truncated_length = 0
    
    for token, indices in token_dict.items():
        for index in indices:
            if index >= prev_index:
                # Truncated part: keep target[prev_index+1:index] and count its length.
                mask = torch.ones_like(target).bool()
                mask[:prev_index+1] = False
                mask[index:] = False
                target_part = target[mask]
                truncated_length += len(target_part)
                prev_index = index

    # Compute the truncation ratio and scale the loss by it (guarding against division by zero).
    if total_length > 0 and truncated_length > 0:
        truncated_ratio = truncated_length / total_length
        loss /= truncated_ratio

    return loss
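

# Illustrative sketch of the masking step in calculate_loss (a toy example, not called during
# training): for a single match the kept slice is target[prev_index+1:index], e.g. with
# prev_index = 1 and index = 4 on a 6-token target.
def _mask_sketch_demo():
    toy_target = torch.arange(6)
    mask = torch.ones_like(toy_target).bool()
    mask[:2] = False   # prev_index = 1 -> drop everything up to and including position 1
    mask[4:] = False   # index = 4      -> drop everything from position 4 onward
    return toy_target[mask]  # tensor([2, 3])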


def replace_tokenizer_code(inputs, output, target):
    batch_size = inputs.size(0)
    seq_len = inputs.size(1)
    num_digits = inputs.size(2)
    loop = False
    index = -1
    # Only the last element of the batch is used for this computation.
    last_batch_inputs = inputs[-1].to(inputs.device)
    last_batch_target = target[-1].to(inputs.device)
    outputval = torch.argmax(output,dim=1).to(inputs.device)
    #print(last_batch_target[0])
    #print("outputval:",outputval)
    zeroval = torch.zeros(num_digits,dtype=torch.int64).to(inputs.device)
    #print("zeroval:",zeroval)

    # Find where the predicted code occurs in the target; the codes up to that position are dropped below.
    for i in range(last_batch_target.size(0)):
        if last_batch_target[i].equal(zeroval):
            index = i
            #print('zero stop,index:',index)
            loop = False
            break
        if torch.eq(last_batch_target[i], outputval).all():
            index = i
            loop = True
            print(" ###########index :",index,'outputval:',outputval)
            print('last_batch_target[',index,']:',last_batch_target[i])
            
            # 遍历inputs，找到第一个包含num_digits个0的位置
            for i in range(seq_len):
                if torch.eq(last_batch_inputs[i], zeroval).all():
                    # 将output的num_digits位数字替换进去
                    last_batch_inputs[i] = outputval
                    break
            # If the last input row is already occupied by something other than the prediction,
            # shift the whole window left by one and start a fresh row holding the prediction.
            if not torch.eq(last_batch_inputs[-1], zeroval).all():
                if not torch.eq(last_batch_inputs[-1], outputval).all():
                    inputs = torch.roll(inputs, shifts=-1, dims=0)
                    inputs[-1] = torch.zeros_like(inputs[-1]).to(inputs.device)
                    inputs[-1][0] = outputval
            break

    if loop:
        # Drop the consumed prefix of the target and pad the tail with zeros.
        last_batch_target = torch.cat(
            (last_batch_target[index + 1:],
             torch.zeros_like(last_batch_target[last_batch_target.size(0) - index - 1:])),
            dim=0)
        
        
    # Write the modified last element back into the original tensors.
    inputs[-1] = last_batch_inputs
    target[-1] = last_batch_target

    return loop, inputs, target, index
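

# Toy walk-through of replace_tokenizer_code (a hedged sketch, not part of the training loop;
# the 4-row window and the vocabulary size of 10 are assumptions made only for this
# illustration): when argmax(output, dim=1) matches a row of target[-1], that row is consumed,
# the prediction is written into the first all-zero row of inputs[-1], and loop stays True so
# the model is queried again on the updated inputs.
def _replace_tokenizer_code_demo():
    toy_inputs = torch.zeros(1, 4, num_digits, dtype=torch.int64)
    toy_target = torch.zeros(1, 4, num_digits, dtype=torch.int64)
    toy_target[0, 0] = 7  # the first expected num_digits-long code
    # Logits whose argmax along dim=1 is the code [7, 7, ..., 7].
    toy_output = F.one_hot(torch.full((num_digits,), 7, dtype=torch.long), num_classes=10).float()
    return replace_tokenizer_code(toy_inputs, toy_output, toy_target)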



crossentropy = nn.CrossEntropyLoss()
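# nn.CrossEntropyLoss expects raw logits of shape (N, C) and class-index targets of shape (N,).
# Here N is num_digits, so `output` is assumed to be (num_digits, vocab)-shaped logits and
# target[-1, 0].view(num_digits) supplies the matching class indices.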

for epoch in range(10000):
    # Training
    model.train()
    for batch_idx, batch in enumerate(train_loader):
        inputs, target = batch
        inputs = inputs.view(-1, inputs.shape[2], inputs.shape[3])
        target = target.view(-1, target.shape[2], target.shape[3])
        inputs = inputs.to(device)
        target = target.to(device)
        index = 0
        loop = True
        #print("start:",target[-1])
        while loop:
            optimizer.zero_grad()
            output = model(inputs)
            #print("output:",output)
            #print("inputs,shape:",inputs.shape,"\noutput.shape:",output.shape,"\ntarget.shape:",target.shape)
            loop, inputs, target, index = replace_tokenizer_code(inputs,output,target)
            
            #print(target[-1])
        # Compute the loss only after the loop exits: while looping, the target keeps being truncated
        # and shifted forward, so the first remaining code is the one to score.
        #print('target[-1,0]:',target[-1,0])
        if torch.nonzero(target[-1,0]).size(0) > 0:
            #print('last index:',index)
            #print("last target:",target[-1,0].view(num_digits))
            loss = crossentropy(output, target[-1,0].view(num_digits))
            loss.backward()
            optimizer.step()
            print(f"Epoch: {epoch}, Train Batch: {batch_idx}, Train Loss: {loss.item():.6f}")

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(val_loader):
            inputs, target = batch
            inputs = inputs.view(-1, inputs.shape[2], inputs.shape[3])
            target = target.view(-1, target.shape[2], target.shape[3])
            inputs = inputs.to(device)
            target = target.to(device)
            index = 0
            loop = True
            while loop:
                output = model(inputs)
                loop, inputs, target, index = replace_tokenizer_code(inputs,output,target)
            # Compute the loss directly as the cross-entropy between the final output and the remaining target.
            loss = crossentropy(output, target[-1,0].view(num_digits))
            val_loss += loss.item()
        # Average the loss over the validation batches.
        val_loss /= len(val_loader)
        print(f"Epoch: {epoch}, Val Loss: {val_loss:.6f}")



    # Save the best model (lowest validation loss)
    if val_loss < best_val_loss:
        model.save(checkpoint_path)
        best_val_loss = val_loss

    # Save the latest model state
    model.save('./run/last.pt')

